import numpy as np
from pathlib import Path
import pandas as pd
import numpy as np
import glob
import matplotlib.pyplot as plt
from multiprocessing import Pool
import seaborn as sns
import math
from matplotlib import ticker


def get_average_from_directory(dir: str, radio_programs=None, window_length_in_minutes=15, offset_hours=6, end_hour=-1, fontsize=-1, label_fontsize= 10, figure_size=(18,10)):
    """Plot every ``.txt`` silence file under *dir* and the average over them.

    Each file found recursively under *dir* is processed with
    ``plot_time_span_file`` (one plot per file). The per-file silence curves
    are truncated to the length of the shortest curve, stacked, averaged and
    drawn with ``plot_average_plot``.

    Args:
        dir: Directory searched recursively for ``*.txt`` files.
        radio_programs: Optional mapping from file stem to a metadata
            DataFrame; forwarded to ``plot_time_span_file``.
        window_length_in_minutes: Half-width of the sliding window used for
            the silence ratio.
        offset_hours: Hour offset subtracted from ``end_hour`` (and, in the
            per-file plots, from metadata clock times).
        end_hour: If positive, hour after which data is discarded.
        fontsize: Tick font size of the per-file plots (unchanged if <= 0).
        label_fontsize: Axis-label font size of the per-file plots.
        figure_size: ``(width, height)`` in inches of the per-file figures.

    Returns:
        Tuple ``(all_results, mean_result, results_dict, plots,
        line_plot_args_list, average_plot_args)``: the stacked truncated
        curves (one row per file), their column mean, per-file curves keyed
        by path, per-file Axes keyed by path, per-file ``create_line_plot``
        ``(args, kwargs)`` keyed by path, and the positional arguments given
        to ``plot_average_plot``.

    Raises:
        Exception: If no ``.txt`` files are found under *dir*.
    """
    files = sorted(glob.glob(f'{dir}/**/*.txt', recursive=True))
    results = []
    results_dict = dict()
    plots = dict()
    line_plot_args_list = dict()

    for file in files:
        print(f'File: {file}')
        # offset_hours is forwarded so the per-file end_hour cut-off matches the
        # one applied to the average below (previously the default was used).
        file_results, ax, line_plot_args = plot_time_span_file(
            file,
            window_length_in_minutes=window_length_in_minutes,
            radio_programs=radio_programs,
            end_hour=end_hour,
            fontsize=fontsize,
            label_fontsize=label_fontsize,
            offset_hours=offset_hours,
            figure_size=figure_size,
        )
        results.append(file_results)
        plots[file] = ax  # plot_time_span_file returns the Axes it drew on
        line_plot_args_list[file] = line_plot_args
        results_dict[file] = file_results

    if not results:
        raise Exception('The path did not seem to include any correct files, or path was incorrect')

    # TODO: Consider what to do with jagged arrays! This can be critical!
    # For now every curve is truncated to the shortest one so they can be stacked.
    min_length = min(len(val) for val in results)
    all_results = np.vstack([np.array(val[0:min_length]) for val in results])
    mean_result = np.mean(all_results, axis=0)

    max_value = len(mean_result)
    if end_hour > 0:
        max_value = min(max_value, 3600 * (end_hour - offset_hours))
    x = np.arange(0, max_value)

    # Previously the return value of print() (None) was stored here, and the
    # title was never handed to the plot.
    last_dir = Path(dir).name
    avg_title = f'Average for {last_dir}, window length {window_length_in_minutes} minutes'
    average_plot_args = [x, all_results]
    plot_average_plot(*average_plot_args, title=avg_title)

    return (all_results, mean_result, results_dict, plots, line_plot_args_list, average_plot_args)

def plot_average_plot(x, all_results, ax=None, title=None, figure_size=(10,10), title_fontsize=12, label_fontsize=12, tick_fontsize=12):
    """Draw the mean of several silence curves with a standard-deviation band.

    Args:
        x: Sample positions in seconds; only ``x[-1]`` is used, to decide how
            many hourly x ticks to draw.
        all_results: 2-D array-like with one row per recording and one column
            per sample (as stacked by ``get_average_from_directory``).
        ax: Axes to draw on. When omitted a new figure of *figure_size* is
            created.
        title: Plot title.
        figure_size: ``(width, height)`` in inches; used only when *ax* is None.
        title_fontsize: Font size of the title.
        label_fontsize: Font size of the axis labels.
        tick_fontsize: Font size of the tick labels.

    Returns:
        The Axes that was drawn on.
    """
    if ax is None:
        fig = plt.figure()
        fig.set_size_inches(figure_size[0], figure_size[1])
        ax = plt.axes()
    # Make *ax* current so the pyplot calls below target it even when the
    # caller passed an Axes that is not the active one.
    plt.sca(ax)

    plt.title(title, fontsize=title_fontsize)
    # Long format: 'variable' is the column (sample) index, 'value' the entry.
    df = pd.DataFrame(all_results).melt()
    # NOTE(review): seaborn >= 0.12 deprecates `ci` in favour of `errorbar='sd'`.
    sns.lineplot(ax=ax, y=df.value, x=df.variable / 3600, ci='sd')

    # X axis to hours
    ax.set_xticks(list(range(math.ceil(x[-1] / 3600 + 1))))
    ax.set_xticklabels(convert_seconds_to_time([xtick * 3600 for xtick in ax.get_xticks()]))

    # Values are ratios, so 1.0 is shown as 100 %.
    ax.yaxis.set_major_formatter(ticker.PercentFormatter(xmax=1))

    plt.xlabel('Time', fontsize=label_fontsize)
    plt.ylabel('Silence', fontsize=label_fontsize)
    plt.yticks(fontsize=tick_fontsize)
    plt.xticks(fontsize=tick_fontsize)
    return ax



from scipy.interpolate import make_interp_spline, BSpline

def create_line_plot(x, y,  title, ylabel, xlabel, offset_hours=6, metadata=None, ax = None, tick_fontsize = 10, title_fontsize = 14, label_fontsize = 12, legend_fontsize=25, legend_marker_size=200, figure_size=(18,10)):
    """Plot a curve over time, optionally marking events from *metadata*.

    Args:
        x: NumPy array of sample positions in seconds (divided by 3600 here).
        y: Values to plot, one per element of *x*; the y axis is formatted as
            a percentage where 1.0 is 100 %.
        title: Plot title.
        ylabel: Y-axis label.
        xlabel: X-axis label.
        offset_hours: Subtracted from the metadata hours so they line up with
            the hour positions of *x*.
        metadata: Optional DataFrame. For every row the first column is
            converted with ``get_hour`` and the third column is used as the
            marker category ("Type").
        ax: Axes to draw on; a new figure of *figure_size* is created if None.
        tick_fontsize: Tick label font size (left unchanged if <= 0).
        title_fontsize: Font size of the title.
        label_fontsize: Font size of the axis labels.
        legend_fontsize: Font size of the legend text.
        legend_marker_size: Marker size applied to the legend handles.
        figure_size: ``(width, height)`` in inches; used only when *ax* is None.

    Returns:
        ``(y, ax)``: the *y* that was passed in and the Axes drawn on.
    """
    if ax is None:
        fig = plt.figure()
        fig.set_size_inches(figure_size[0], figure_size[1])
        ax = plt.axes()
    plt.sca(ax)
    # title_fontsize used to be ignored (label_fontsize + 2 was used); both
    # default to 14, so default output is unchanged.
    plt.title(title, fontsize=title_fontsize)

    x = x / 3600  # Convert from seconds to hours

    # Adds labels (if available) from metadata, i.e. the excel-files
    if metadata is not None:
        # Predefined RGBA colours (0-255), scaled to the 0-1 range matplotlib uses.
        rgba_255 = [[27, 158, 119, 255], [217, 95, 2, 255], [117, 112, 179, 255]]
        colors = [np.divide(np.array(c), 255) for c in rgba_255]

        # .iloc: positional Series[int] access is deprecated in pandas 2.
        labels = [(get_hour(row.iloc[0]), row.iloc[2]) for _, row in metadata.iterrows()]

        df = pd.DataFrame(labels, columns=['Hour', 'Type'])
        df['X'] = df['Hour'] - offset_hours
        # Put each marker on the curve, at the sample closest to its time.
        df['Y'] = [y[np.abs(x - timestamp).argmin()] for timestamp in df['X']]

        sns.scatterplot(ax=ax, data=df, x="X", y="Y", hue_order=sorted(df['Type'].unique().tolist()), hue="Type", palette=colors, s=200)
        plot_legend = plt.legend(fontsize=legend_fontsize)
        # matplotlib 3.7 renamed `legendHandles` to `legend_handles`; the old
        # name was removed in 3.9.
        handles = getattr(plot_legend, 'legend_handles', None)
        if handles is None:
            handles = plot_legend.legendHandles
        for handle in handles:
            handle._sizes = [legend_marker_size]
    ax.plot(x, y)
    if tick_fontsize > 0:
        plt.yticks(fontsize=tick_fontsize)
        plt.xticks(fontsize=tick_fontsize)

    # X axis to hours
    ax.set_xticks(list(range(math.ceil(x[-1] + 1))))
    ax.set_xticklabels(convert_seconds_to_time([xtick * 3600 for xtick in ax.get_xticks()]))

    ax.yaxis.set_major_formatter(ticker.PercentFormatter(xmax=1))

    plt.xlabel(xlabel, fontsize=label_fontsize)
    plt.ylabel(ylabel, fontsize=label_fontsize)
    return y, ax
    


def do_spline_interpolation(data, points=250, k=3):
    """Resample *data* at *points* evenly spaced positions with a B-spline.

    The samples are placed at x = 0, 1, ..., len(data) - 1. An interpolating
    B-spline of degree *k* is built through them and evaluated at *points*
    evenly spaced positions covering exactly that range.

    Args:
        data: 1-D sequence of samples. It needs at least ``k + 1`` elements,
            which ``make_interp_spline`` requires.
        points: Number of output positions.
        k: Spline degree.

    Returns:
        ``(x_new, y_new)``: the evaluation positions and the spline values there.
    """
    x = np.arange(len(data))
    # End at the last sample index. The previous upper bound of len(data)
    # made the final outputs extrapolations beyond the data.
    x_new = np.linspace(0, len(data) - 1, points)
    spline = make_interp_spline(x, data, k=k)
    y_new = spline(x_new)
    return x_new, y_new


# Called once per sampled second by plot_time_span_file, which already spreads the calls over a multiprocessing Pool.
def find_average_time(sec, silence_start, silence_end, window_length_in_minutes):
    """Return the fraction of the window around *sec* that is covered by silence.

    The window is ``[sec - W, sec + W]`` with ``W = window_length_in_minutes * 60``,
    clipped to ``[0, silence_end.max()]``. Silence intervals are the pairs
    ``(silence_start[i], silence_end[i])``. Only the first start and the last
    end inside the window are clipped, so the intervals appear to be assumed
    sorted and non-overlapping (TODO confirm against the input files).

    Args:
        sec: Centre of the window, in the same unit as the intervals (seconds).
        silence_start: NumPy array of interval start times.
        silence_end: NumPy array of interval end times.
        window_length_in_minutes: Half-width of the window in minutes.

    Returns:
        Silent fraction of the window. It is 1 when a single interval spans the
        whole window and 0 when no interval touches it.
    """
    window_length = window_length_in_minutes * 60
    end_time = silence_end.max()

    window_start = np.max([sec - window_length, 0])
    window_end = np.min([sec + window_length, end_time])

    start_within = (silence_start >= window_start) & (silence_start <= window_end)
    end_within = (silence_end >= window_start) & (silence_end <= window_end)
    relevant = start_within | end_within

    if not relevant.any():
        # No boundary lies inside the window, so the window is either entirely
        # inside one silence interval or contains no silence at all.
        covering = (silence_start < window_start) & (silence_end > window_end)
        return 1 if covering.any() else 0

    # Boolean indexing returns copies, so the clipping below leaves the inputs intact.
    relevant_start = silence_start[relevant]
    relevant_end = silence_end[relevant]
    relevant_start[0] = np.max([window_start, relevant_start[0]])
    relevant_end[-1] = np.min([window_end, relevant_end[-1]])

    return np.sum(relevant_end - relevant_start) / (window_end - window_start)

from datetime import datetime
def get_hour(time):
    """Convert a clock time to fractional hours since midnight.

    Args:
        time: Either a ``'HH:MM:SS'`` string or an object exposing ``hour``,
            ``minute`` and ``second`` attributes (e.g. ``datetime.time`` or
            ``datetime.datetime``).

    Returns:
        Hours as a float, e.g. ``'06:30:00'`` -> ``6.5``.

    Raises:
        Exception: If *time* is None.
        ValueError: If a string does not match ``'%H:%M:%S'``.
    """
    if time is None:
        raise Exception('Error: received empty time!')
    # isinstance is required: the former `time is str` compared against the
    # type object itself, was never true, and made string input crash.
    dt = datetime.strptime(time, '%H:%M:%S') if isinstance(time, str) else time
    seconds = dt.hour * 3600 + dt.minute * 60 + dt.second
    return seconds / 3600

import re

def get_date_from_filename(filename:str ):
    """Extract the ``YYYY-MM-DD`` date embedded in *filename*.

    If several dates appear, the greedy leading ``.*`` makes the last one win,
    which helps with full paths whose directories also contain dates.

    Args:
        filename: File name or path to search.

    Returns:
        ``(year, month, day)`` as strings, e.g. ``('2021', '03', '04')``.

    Raises:
        ValueError: If *filename* contains no such date.
    """
    # Raw string: '\d' in a plain literal is an invalid escape sequence
    # (DeprecationWarning, and a SyntaxWarning since Python 3.12).
    match_result = re.match(r'.*(\d\d\d\d)-(\d\d)-(\d\d).*', filename)
    if match_result is None:
        raise ValueError(f'No YYYY-MM-DD date found in {filename!r}')
    year, month, day = match_result.groups()
    return (year, month, day)


def plot_time_span_file(file: str, window_length_in_minutes=5, radio_programs=None, end_hour=-1, fontsize=-1, label_fontsize=10, offset_hours=6, figure_size=(18,10), processes=14):
    """Compute and plot the sliding-window silence ratio of one recording.

    *file* is read as a headerless tab-separated table. Column 0 holds the
    silence start times and column 1 the silence end times, presumably in
    seconds (TODO confirm). ``find_average_time`` is evaluated for every whole
    second up to the last silence end, or up to *end_hour* if given, using a
    pool of worker processes. The result is then plotted with
    ``create_line_plot``.

    Args:
        file: Path to the tab-separated silence file.
        window_length_in_minutes: Half-width of the sliding window.
        radio_programs: Optional mapping from file stem to a metadata
            DataFrame used to mark program events on the plot.
        end_hour: If positive, hour after which data is discarded.
        fontsize: Tick font size (left unchanged if <= 0).
        label_fontsize: Axis-label font size.
        offset_hours: Hour offset subtracted from ``end_hour`` and from the
            metadata times.
        figure_size: ``(width, height)`` of the figure in inches.
        processes: Number of worker processes for the per-second computation.

    Returns:
        ``(y, ax, line_plot_args)``: the silence ratio for each second, the
        Axes drawn on, and the ``(args, kwargs)`` passed to ``create_line_plot``
        so the plot can be recreated.
    """
    df = pd.read_csv(file, sep='\t', header=None)
    file_path = Path(file)

    metadata_df = None
    try:
        metadata_df = radio_programs[file_path.stem]
        print(f'Found metadata for file {file}')
    except (KeyError, TypeError):
        # No metadata for this file, or radio_programs is None.
        pass

    silence_start = df[0].to_numpy()
    silence_end = df[1].to_numpy()

    try:
        year, month, day = get_date_from_filename(file)
        title = f'Date: {year}-{month}-{day}'
    except (AttributeError, ValueError):
        # Raised by get_date_from_filename when the name contains no date.
        print('exception when extracting date... using default title')
        title = file

    max_value = np.floor(silence_end.max()).astype(int)
    if end_hour > 0:
        max_value = min(max_value, 3600 * (end_hour - offset_hours))
    x = np.arange(0, max_value)

    print('Starting calculations...')
    new_iter = [(xx, silence_start, silence_end, window_length_in_minutes) for xx in x]
    # The context manager shuts the workers down after each file; previously
    # every call leaked a pool of processes.
    with Pool(processes) as pool:
        y = pool.starmap(find_average_time, new_iter)
    print('Finished.')

    # args and kwargs for create_line_plot; fontsize and label_fontsize used to
    # be accepted but silently ignored.
    line_plot_args = (
        (x, y, title, 'Silence', 'Time'),
        {
            'offset_hours': offset_hours,
            'metadata': metadata_df,
            'figure_size': figure_size,
            'tick_fontsize': fontsize,
            'label_fontsize': label_fontsize,
        },
    )

    y, ax = create_line_plot(*line_plot_args[0], **line_plot_args[1])
    return (y, ax, line_plot_args)


from matplotlib import cm
from matplotlib.colors import ListedColormap, LinearSegmentedColormap
def _get_colors():
    pass
    

# def plot_time_span(silence_start, silence_end, window_length_in_minutes=5, title=None, metadata=None, offset_hours=6, end_hour=-1, fontsize=-1, label_fontsize = 10):
#     max_value = np.floor(silence_end.max()).astype(int)

# #     colors = {'Music':'purple', 'News':'green', 'Misc':'red'}
#     if end_hour > 0:
#         max_value = min(max_value, 3600 * (end_hour - offset_hours))    
#     x = np.arange(0,max_value)
    
#     pool=Pool(14)
#     new_iter = [(xx, silence_start, silence_end, window_length_in_minutes) for xx in x]
#     y = list(pool.starmap(find_average_time, new_iter))
# #     y = list(map(lambda time: time_span_functions.find_average_time(time, silence_start, silence_end, window_length_in_minutes), x))

#     return create_line_plot(x, y, title, 'Silence', 'Time', metadata=metadata) 

#     fig = plt.figure()
#     fig.set_size_inches(18.5, 10.5)
#     ax = plt.axes()

#     plt.title(title,fontsize=label_fontsize+2)

#     x = x / 3600 # Convert from seconds to hours

#     # Adds labels (if available) from metadata, i.e. the excel-files
#     labels = list()
#     if metadata is not None:
#         # Predifinied colors!
#         colors = [[27,158,119,255], [217,95,2,255], [117,112,179,255]]
#         colors = [np.divide(np.array(c),255) for c in colors]


#         for meta in metadata.iterrows():
#             # TODO: Fix this indexing, clearly rows can more easily be accessed...
#             hour = get_hour(meta[1][0])
#             labels.append((hour, meta[1][2]))    
 
#         # Clearly this could be cleaner, as we do start with a dataframe...
#         df = pd.DataFrame(labels,columns=['Hour','Type'])
#         df['X'] = df['Hour'] - offset_hours
#         df['Y'] = [y[np.abs(x-timestamp).argmin()] for timestamp in df['X']]

        

#         sns.scatterplot(ax=ax, data=df, x="X", y="Y", hue_order=sorted(df['Type'].unique().tolist()), hue="Type", palette=colors, s=200)
#         plot_legend = plt.legend(fontsize=25)
#         for handle in plot_legend.legendHandles:
#             handle._sizes = [200]
        
#         # plt.legend(fontsize='x-large', title_fontsize='40')
    
    
    
#     ax.plot(x , y)
#     if fontsize > 0:
#         plt.yticks(fontsize = fontsize)
#         plt.xticks(fontsize = fontsize)
# #     ax.set_xticks(list(np.arange(np.rint(np.ceil(x[-1]+1))))
    
#     # X axis to hours
#     ax.set_xticks(list(range(math.ceil(x[-1]+1))))
#     ax.set_xticklabels(convert_seconds_to_time([xtick * 3600 for xtick in ax.get_xticks()]))
    
    
# #     ax.set_yticklabels([f'{math.floor(ytick * 100)}' for ytick in ax.get_yticks()])
#     ax.yaxis.set_major_formatter(ticker.PercentFormatter(xmax=1))
    
    
    
#     plt.xlabel('Time', fontsize=label_fontsize)
#     plt.ylabel('Silence', fontsize=label_fontsize)

    
#     return y, fig


def convert_seconds_to_time(time, start_hour=6):
    """Format offsets in seconds as ``'HH:MM'`` labels shifted by *start_hour*.

    Each offset is converted with ``pd.to_datetime(..., unit='s')``. The hour
    of the result has *start_hour* added, and that sum is not reduced modulo
    24, so 20 h with the default start gives ``'26:00'``.
    NOTE(review): offsets of 24 h or more wrap inside the datetime conversion.

    Args:
        time: Sequence of offsets in seconds.
        start_hour: Hours added to every label.

    Returns:
        List of ``'HH:MM'`` strings, one per input offset.
    """
    stamps = pd.to_datetime(time, unit='s')
    return [
        f'{hour + start_hour:02}:{minute:02}'
        for hour, minute in zip(stamps.hour, stamps.minute)
    ]

# Find index of closest hour
def find_tick_indexes(seconds):
    """Return, for every whole hour, the index of the closest sample.

    Hours ``0 .. floor(max(seconds) / 3600)`` are considered. The index of the
    last sample is always appended, so it can repeat the final hourly index.

    Args:
        seconds: Non-empty sequence of sample times in seconds.

    Returns:
        List of indexes into *seconds*.
    """
    samples = np.array(seconds)
    hour_count = int(np.floor(np.max(seconds) / 3600)) + 1
    indexes = [np.argmin(np.abs(samples - hour * 3600)) for hour in range(hour_count)]
    indexes.append(len(seconds) - 1)
    return indexes
